{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import IPython.display\n",
    "\n",
    "import librosa\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import OrderedDict\n",
    "\n",
    "# load other modules --> repo root path\n",
    "sys.path.insert(0, \"../\")\n",
    "\n",
    "import torch\n",
    "from utils import text, audio\n",
    "from utils import remove_dataparallel_prefix\n",
    "from utils.logging import Logger\n",
    "from params.params import Params as hp\n",
    "from modules.tacotron2 import Tacotron\n",
    "from dataset.dataset import TextToSpeechDataset, TextToSpeechDatasetCollection, TextToSpeechCollate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(checkpoint):\n",
    "    \"\"\"Create a Tacotron model and initialize it from a checkpoint file.\n",
    "\n",
    "    The checkpoint's 'parameters' entry is loaded into the global `hp` before\n",
    "    the model is constructed, and its 'model' entry supplies the weights.\n",
    "    Keys present on only one side are printed; only keys the model already\n",
    "    has are copied over. The model is moved to CUDA when it is available.\n",
    "\n",
    "    Args:\n",
    "        checkpoint: path (or file-like object) accepted by `torch.load`.\n",
    "\n",
    "    Returns:\n",
    "        The initialized model, placed on the selected device.\n",
    "    \"\"\"\n",
    "    # NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.\n",
    "    target = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    snapshot = torch.load(checkpoint, map_location=target)\n",
    "    hp.load_state_dict(snapshot['parameters'])\n",
    "\n",
    "    net = Tacotron()\n",
    "    current = net.state_dict()\n",
    "    loaded = remove_dataparallel_prefix(snapshot['model'])\n",
    "\n",
    "    # keys stored in the checkpoint that the freshly built model does not have\n",
    "    missing = [name for name in loaded if name not in current]\n",
    "    print(f'Missing model parts: {missing}')\n",
    "\n",
    "    # keys the model has that the checkpoint does not provide\n",
    "    missing = [name for name in current if name not in loaded]\n",
    "    print(f'Extra model parts: {missing}')\n",
    "\n",
    "    # overwrite the model's own entries with every matching checkpoint entry\n",
    "    for name, tensor in loaded.items():\n",
    "        if name in current:\n",
    "            current[name] = tensor\n",
    "\n",
    "    net.load_state_dict(current)\n",
    "    net.to(target)\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(model, inputs, device=None):\n",
    "    \"\"\"Synthesize one spectrogram per non-empty input record.\n",
    "\n",
    "    Each record is a `|`-separated string `text|speaker|languages`:\n",
    "\n",
    "    * `text` is cleaned according to `hp.use_punctuation`, `hp.case_sensitive`\n",
    "      and `hp.remove_multiple_wspaces` and then converted with `text.to_sequence`.\n",
    "    * `speaker` is looked up in `hp.unique_speakers` (only if `hp.multi_speaker`).\n",
    "    * `languages` (only if `hp.multi_language`) is a comma-separated list of\n",
    "      tokens `spec` or `spec-length`. `spec` is a language code from\n",
    "      `hp.languages` or a `:`-separated blend of `code*weight` items; `length`\n",
    "      is the number of positions the token covers. A token without a length\n",
    "      covers all remaining positions out of `len(clean_text) + 1`.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing `inference(text, speaker=..., language=...)`.\n",
    "        inputs: iterable of record strings; empty strings are skipped.\n",
    "        device: \"cuda\" moves the input tensors to the GPU, None does so only\n",
    "            when CUDA is available, any other value leaves them on the CPU.\n",
    "\n",
    "    Returns:\n",
    "        List of numpy arrays, one per record, in input order.\n",
    "    \"\"\"\n",
    "    records = [line.rstrip().split('|') for line in inputs if line]\n",
    "\n",
    "    # The device decision does not depend on the record, so take it once.\n",
    "    use_cuda = device == \"cuda\" or (device is None and torch.cuda.is_available())\n",
    "\n",
    "    spectrograms = []\n",
    "    for record in records:\n",
    "\n",
    "        clean_text = record[0]\n",
    "\n",
    "        if not hp.use_punctuation:\n",
    "            clean_text = text.remove_punctuation(clean_text)\n",
    "        if not hp.case_sensitive:\n",
    "            clean_text = text.to_lower(clean_text)\n",
    "        if hp.remove_multiple_wspaces:\n",
    "            clean_text = text.remove_odd_whitespaces(clean_text)\n",
    "\n",
    "        t = torch.LongTensor(text.to_sequence(clean_text, use_phonemes=hp.use_phonemes))\n",
    "\n",
    "        if hp.multi_language:\n",
    "            l_tokens = record[2].split(',')\n",
    "            # positions still to be covered by the remaining language tokens\n",
    "            t_length = len(clean_text) + 1\n",
    "            l = []\n",
    "            for token in l_tokens:\n",
    "                l_d = token.split('-')\n",
    "\n",
    "                # one weight per known language; a bare code means weight 1\n",
    "                language = [0] * hp.language_number\n",
    "                for l_cw in l_d[0].split(':'):\n",
    "                    l_cw_s = l_cw.split('*')\n",
    "                    language[hp.languages.index(l_cw_s[0])] = 1 if len(l_cw_s) == 1 else float(l_cw_s[1])\n",
    "                # a token without an explicit length takes everything that is left\n",
    "                language_length = (int(l_d[1]) if len(l_d) == 2 else t_length)\n",
    "                l += [language] * language_length\n",
    "                t_length -= language_length\n",
    "            l = torch.FloatTensor([l])\n",
    "        else:\n",
    "            l = None\n",
    "\n",
    "        s = torch.LongTensor([hp.unique_speakers.index(record[1])]) if hp.multi_speaker else None\n",
    "\n",
    "        if use_cuda:\n",
    "            t = t.cuda(non_blocking=True)\n",
    "            # Compare with None explicitly: truth-testing a multi-element tensor\n",
    "            # raises, and a one-element tensor holding speaker index 0 is falsy.\n",
    "            if l is not None:\n",
    "                l = l.cuda(non_blocking=True)\n",
    "            if s is not None:\n",
    "                s = s.cuda(non_blocking=True)\n",
    "\n",
    "        spectrograms.append(model.inference(t, speaker=s, language=l).cpu().detach().numpy())\n",
    "\n",
    "    return spectrograms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synthesis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the checkpoint file to synthesize from; it is passed to torch.load\n",
    "# in the cells below, so it must be filled in before running them.\n",
    "checkpoint = \"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint onto the CPU and select its stored 'parameters' entry.\n",
    "# The trailing semicolon suppresses the cell output, so nothing is displayed.\n",
    "torch.load(checkpoint, map_location=\"cpu\")['parameters'];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override one hyperparameter before the model is built.\n",
    "# NOTE(review): build_model() calls hp.load_state_dict(...) on the checkpoint's\n",
    "# parameters afterwards, which may overwrite this value -- confirm it survives.\n",
    "hp.decoder_language_init = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = build_model(checkpoint)\n",
    "model.eval()\n",
    "\n",
    "# Print the hyperparameters in effect after build_model loaded the checkpoint's parameters.\n",
    "for name in (\n",
    "    \"encoder_type\",\n",
    "    \"encoder_dimension\",\n",
    "    \"languages\",\n",
    "    \"unique_speakers\",\n",
    "    \"multi_language\",\n",
    "    \"multi_speaker\",\n",
    "):\n",
    "    print(getattr(hp, name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same text and speaker, with the language blend swept from pure \"fr\" to\n",
    "# pure \"de\" in steps of 0.1 (record format: text|speaker|languages).\n",
    "inputs = (\n",
    "    [\"jean paul belmondo|00-de|fr\"]\n",
    "    + [f\"jean paul belmondo|00-de|de*0.{k}:fr*0.{10 - k}\" for k in range(1, 10)]\n",
    "    + [\"jean paul belmondo|00-de|de\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every sentence is paired with every (speaker, language) combination;\n",
    "# records are grouped by speaker, sentences in the order listed here.\n",
    "sentences = [\n",
    "    \"Das schließlich realisierte projekt vereint die Wien und die Wiener Stadtbahn in einem einheitlichen Profil.\",\n",
    "    \"Tijdens de integratie liepen veel klanten weg omdat containers en documenten verdwenen en foutieve rekeningen werden gestuurd.\",\n",
    "    \"Совет обладает полномочиями по руководству всей операционной деятельностью Центрального банка и контролю за коммерческими оманскими банками.\",\n",
    "    \"C'est l'un des plus beaux palais de Mistretta dont le nom dérive d'une ancienne famille seigneuriale de la ville.\",\n",
    "    \"yán zào cì luò， xiānruǎn xiā zhǔ dòufǔ děng shípǐn de pēngdiào jìyì wèi rén suǒ chēngdào， dà cì， bóké mǐ， gāoxiè， zhūtóu zòngděng tǔtèchǎn míngwénxiáěr。\",\n",
    "]\n",
    "speaker_language_pairs = [(\"00-de\", \"de\"), (\"00-fr\", \"fr\"), (\"00-nl\", \"nl\"), (\"00-zh\", \"zh\"), (\"09-ru\", \"ru\")]\n",
    "\n",
    "inputs = [\n",
    "    f\"{sentence}|{speaker}|{language}\"\n",
    "    for speaker, language in speaker_language_pairs\n",
    "    for sentence in sentences\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code-switching examples, one record per line: text|speaker|languages.\n",
    "# The languages field is a comma-separated list of `code-length` tokens, each\n",
    "# assigning `length` consecutive positions of the text to `code`; the final\n",
    "# token has no length and covers whatever remains (see inference()).\n",
    "inputs = [\n",
    "    \"Charles André Joseph Marie de Gaulle war ein französischer General und Staatsmann.|00-de|fr-36,de\",\n",
    "    \"Unter anderem förderte er dank Pierre Brossolette und besonders Jean Moulin die Résistance.|00-de|de-31,fr-18,de-15,fr-11,de-5,fr-10,de\",\n",
    "    \"Im April ersetzte de Gaulle den Premierminister Michel Debré durch Georges Pompidou.|00-de|de-18,fr-9,de-21,fr-12,de-7,fr-16,de\",\n",
    "    \"Jean Marie Bastien-Thiry war mit dessen Algerien-Politik nicht länger einverstanden.|00-de|fr-24,de\",\n",
    "    \"Sein Name wurde auch dem gegenwärtig letzten französischen Flugzeugträger, der Charles de Gaulle gegeben.|00-de|de-79,fr-17,de\",\n",
    "    \"Die römisch-katholische Kirche Notre-Dame de Paris ist die Kathedrale des Erzbistums Paris.|00-de|de-31,fr-19,de\",\n",
    "    \"Neben Molière oder Balzac gilt er vielen Franzosen als ihr größter Autor überhaupt.|00-de|de-6,fr-7,de-6,fr-6,de\",\n",
    "    \"Auf der Geburtsanzeige seines dritten Kindes nennt Hugo sich stolz Baron.|00-de|de-51,fr-4,de\",\n",
    "    \"Bei der Abfassung hatte Chateaubriand aber sicher auch opportunistische Motive:|00-de|de-24,fr-13,de\",\n",
    "    \"Er musste die Briefe, die er täglich an Juliette Récamier schrieb, diktieren und unterschrieb sie mit zitternder Hand.|00-de|de-40,fr-17,de\",\n",
    "    \"Die noch bewohnte Cité de Carcassonne wird von einem doppelten Mauerring umschlossen.|00-de|de-18,fr-19,de\",\n",
    "    \"Flussaufwärts sind die nächsten Brücken der Pont de Tancarville und der Pont de Brotonne.|00-de|de-44,fr-19,de-9,fr-16,de\",\n",
    "    \"Die Bevölkerung von Marseille war seit jeher stolz und unabhängig und im ganzen Land dafür bekannt,|00-de|de-20,fr-9,de\",\n",
    "    \"François Hollande ist ein französischer Politiker der Sozialistischen Partei und war Staatspräsident der Französischen Republik.|00-de|fr-17,de\",\n",
    "    \"Sein Amtssitz ist der Élysée-Palast in Paris.|00-de|de-22,fr-13,de\",\n",
    "    \"Versailles ist eine französische Stadt in der Region Île-de-France.|00-de|fr-10,de-43,fr-13,de\",\n",
    "    \"Die waldreichen Talhänge bilden mit dem Bois des Fonds des Maréchaux.|00-de|de-40,fr-28,de\",\n",
    "    \"Zu Beginn seiner Herrschaft hielt sich der Monarch allerdings jeweils nur für kurze Zeit in Versailles auf.|00-de|de-92,fr-10,de\",\n",
    "    \"In Versailles benötigten sie häufig nur angemietete Unterkünfte, wovon insbesondere die Gastwirtschaft profitierte.|00-de|de-3,fr-10,de\",\n",
    "    \"Nachdem der König der Nationalversammlung jedoch den Zutritt in das Hôtel des Menus Plaisirs verweigert hatte,|00-de|de-68,fr-24,de\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Put the input tensors on the device the model lives on. build_model() moves\n",
    "# the model to CUDA when it is available, so hard-coding \"cpu\" here would feed\n",
    "# CPU tensors to a GPU model on such machines.\n",
    "model_device = next(model.parameters()).device.type\n",
    "generated_spectrograms = inference(model, inputs, model_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Denormalize each generated spectrogram, invert it, and show an audio player for the result.\n",
    "for spectrogram in generated_spectrograms:\n",
    "    denormalized = audio.denormalize_spectrogram(spectrogram, not hp.predict_linear)\n",
    "    waveform = audio.inverse_spectrogram(denormalized, not hp.predict_linear)\n",
    "    player = IPython.display.Audio(data=waveform, rate=hp.sample_rate)\n",
    "    IPython.display.display(player)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
